from sklearn.metrics import confusion_matrix


def _class_indices(encoded):
    """Return a 1-D integer array of class indices.

    A 2-D input (one-hot rows or per-class scores) is reduced with ``argmax``
    along axis 1; a 1-D input is taken to already hold class indices.
    """
    arr = np.asarray(encoded)
    if arr.ndim == 1:
        return arr.astype(int)
    # argmax matches the old np.nonzero lookup on one-hot rows, but also works
    # for probability/score rows that have more than one non-zero entry.
    return np.argmax(arr, axis=1)


def _confusion_table(cm, labels):
    """Build heatmap values and cell annotations for an ``n x n`` confusion matrix.

    Rows of ``cm`` are true classes, columns are predicted classes.

    Returns ``(heat_df, annotations)``:
    - ``heat_df``: ``(n+1, n+1)`` DataFrame. The top-left ``n x n`` block holds
      the counts of ``cm``; the extra "Recall" column and "Precision" row are
      filled with the grand total so the margins take the extreme colour of the
      colormap.
    - ``annotations``: string array of the same shape. Body cells show
      "count\\nshare-of-total%", the last column shows per-class recall, the
      last row per-class precision, and the corner the overall accuracy.
    """
    n = len(labels)
    total = cm.sum()
    diag = np.diag(cm)
    # Classes with no true (or no predicted) samples give NaN instead of
    # emitting a division RuntimeWarning.
    with np.errstate(divide="ignore", invalid="ignore"):
        recall = 100 * diag / cm.sum(axis=1)      # TP / actual class size
        precision = 100 * diag / cm.sum(axis=0)   # TP / predicted class size
        share = 100 * cm / total
        accuracy = 100 * diag.sum() / total

    rows = [
        [f"{cm[i, j]:.0f}\n{share[i, j]:.2f}%" for j in range(n)] + [f"{recall[i]:.2f}%"]
        for i in range(n)
    ]
    rows.append([f"{p:.2f}%" for p in precision] + [f"Accuracy\n{accuracy:.2f}%"])
    # dtype=str lets numpy size the strings to the longest label (no truncation).
    annotations = np.array(rows, dtype=str)

    heat = np.full((n + 1, n + 1), total, dtype=float)
    heat[:n, :n] = cm
    heat_df = pd.DataFrame(
        heat,
        index=[*labels, "Precision"],
        columns=[*labels, "Recall"],
    )
    return heat_df, annotations


def plot_confusion(
    y_true=None,
    y_pred=None,
    *,
    labels=("Undamaged", "Low Damage", "High Damage"),
    title="A-60dB Confusion Matrix on Test Set",
    filename="confusion matrix.png",
):
    """Plot, save and show an annotated confusion matrix with recall/precision margins.

    Parameters
    ----------
    y_true, y_pred : array-like, optional
        Ground-truth and predicted classes, either as 2-D one-hot/score arrays
        (reduced with argmax per row) or as 1-D class indices. When omitted,
        the module-level ``y_true`` / ``y_pred`` are used, so the original
        zero-argument call keeps working.
    labels : sequence of str
        Display names for class indices ``0 .. len(labels) - 1``.
    title : str
        Title prefix; " (Size=<number of samples>)" is appended.
    filename : str
        Path the figure is saved to before ``plt.show()``.

    NOTE(review): relies on module-level ``np``, ``pd`` and ``plt`` that are
    not imported in this section of the file — confirm they are imported
    elsewhere.
    """
    import seaborn as sns

    # Backward compatibility: the original read these arrays from module globals.
    if y_true is None:
        y_true = globals()["y_true"]
    if y_pred is None:
        y_pred = globals()["y_pred"]

    labels = list(labels)
    # Passing labels= keeps the matrix n x n even if a class never occurs.
    cm = confusion_matrix(
        _class_indices(y_true),
        _class_indices(y_pred),
        labels=list(range(len(labels))),
    )
    heat_df, annotations = _confusion_table(cm, labels)

    fig, ax = plt.subplots(figsize=(7, 6))
    sns.heatmap(heat_df, annot=annotations, fmt="", cmap="Greens_r", cbar=False, ax=ax)
    ax.set_title(f"{title} (Size={cm.sum()})", fontsize=16, fontweight="bold")
    ax.set_ylabel("True Values", fontsize=14, fontweight="bold")
    ax.set_xlabel("Predicted Values", fontsize=14, fontweight="bold")
    plt.setp(ax.get_yticklabels(), va="center", rotation=90)
    fig.tight_layout()

    fig.savefig(filename)
    plt.show()